#!/usr/bin/env python

#
# Copyright 2019-2020 NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import argparse
import os.path

from genomeworks import cuda
from genomeworks import cudapoa
from genomeworks.io.utils import read_poa_group_file

"""
Sample program for using CUDAPOA Python API for consensus generation.
"""


def run_cudapoa(groups, msa, print_output):
    """
    Generate consensus or MSA for POA groups.

    Groups are added to one CUDAPOA batch until the batch reports that it is
    full or every group has been handed over. The batch is then processed,
    its results are optionally printed, and it is reset so the remaining
    groups can be added.

    Args:
        groups : A list of groups (i.e. list of sequences) for which
                 consensus is to be generated.
        msa : Whether to generate MSA or consensus.
        print_output : Print output MSA or consensus.
    """
    # Nothing to process, so skip querying the GPU and creating a batch.
    if not groups:
        return

    # Get available memory information
    free, _ = cuda.cuda_get_mem_info(cuda.cuda_get_device())
    gpu_mem_per_batch = 0.9 * free

    # Calculate max bounds for sequence size and sequences per POA.
    max_sequences_per_poa = max(len(group) for group in groups)
    max_seq_sz = max((len(seq) for group in groups for seq in group), default=0)

    # Create batch
    batch = cudapoa.CudaPoaBatch(max_sequences_per_poa,
                                 max_seq_sz,
                                 gpu_mem_per_batch,
                                 output_type=("msa" if msa else "consensus"),
                                 cuda_banded_alignment=True)

    num_groups = len(groups)
    initial_count = 0
    # Indices (into groups) of the groups accepted into the current batch.
    # NOTE(review): assumes a group rejected by add_poa_group is not part of
    # the batch, so results line up with this list - confirm against the API.
    pending = []
    poa_index = 0
    while poa_index < num_groups:
        group_status, seq_status = batch.add_poa_group(groups[poa_index])
        # Status 0 means the group was added, status 1 means the batch is full.
        batch_full = (group_status == 1)

        if group_status == 0:
            for seq_index, status in enumerate(seq_status):
                if status != 0:
                    print("Could not add sequence {} to POA {} - error {}".format(seq_index,
                                                                                  poa_index, cudapoa.status_to_str(status)))
            pending.append(poa_index)
            poa_index += 1
        elif not batch_full or not pending:
            # Either a genuine error, or the batch claims to be full while
            # empty (retrying after a reset could never succeed). Report the
            # group and move on instead of looping forever.
            print("Could not add POA group {} to batch - {}".format(poa_index, cudapoa.status_to_str(group_status)))
            poa_index += 1

        # Once batch is full or no groups are left, run POA processing on
        # whatever has been accepted so far. When the batch was full,
        # poa_index was not advanced, so the rejected group is retried after
        # the reset below.
        if pending and (batch_full or poa_index == num_groups):
            batch.generate_poa()

            if msa:
                # Use a separate name so the msa flag is not overwritten and
                # later batches keep producing the requested output type.
                msa_output, _ = batch.get_msa()
                if print_output:
                    print(msa_output)
            else:
                consensus, coverage, con_status = batch.get_consensus()
                for group_id, status in zip(pending, con_status):
                    if status != 0:
                        print("Could not get consensus for POA group {} - {}".format(group_id,
                                                                                     cudapoa.status_to_str(status)))
                if print_output:
                    print(consensus)

            print("---- Processed group {} - {} ----".format(initial_count, poa_index))
            initial_count = poa_index
            pending = []
            batch.reset()


def parse_args(argv=None):
    """
    Parse command line arguments.

    Args:
        argv : Optional list of argument strings to parse. When None,
               argparse falls back to the process command line.

    Returns:
        An argparse.Namespace with the boolean attributes m, p and l.
    """
    parser = argparse.ArgumentParser(
        description="CUDAPOA Python API sample program.")
    parser.add_argument('-m',
                        help="Run MSA generation. By default consensus is generated.",
                        action='store_true')
    parser.add_argument('-p',
                        help="Print output MSA or consensus for each POA group.",
                        action='store_true')
    parser.add_argument('-l',
                        help="Use long or short read sample data.",
                        action='store_true')
    return parser.parse_args(argv)


if __name__ == "__main__":

    # Parse the command line options.
    cli_args = parse_args()

    # Resolve the sample data file: it is looked up under cudapoa/data, two
    # directory levels above the directory containing this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.dirname(os.path.dirname(script_dir))
    if cli_args.l:
        sample_name = "sample-bonito.txt"
    else:
        sample_name = "sample-windows.txt"
    sample_path = os.path.join(os.path.join(repo_root, "cudapoa", "data"), sample_name)

    # Load the POA groups from the selected sample file.
    poa_groups = read_poa_group_file(sample_path, 1000)

    # Hand the groups to CUDAPOA with the requested output options.
    run_cudapoa(poa_groups, cli_args.m, cli_args.p)
